#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Feb 17 11:44:57 2024

@author: jcfq2
"""

import os

# Base directory of the JWST data products.  The working directory is changed
# here so that the relative paths loaded later in this script
# ('saturn_saves/...npy', 'saturn_spectra_*.npy') resolve against it.
jwst_dir='/Users/jcfq2/data/observations/jwst'

os.chdir(jwst_dir)
import matplotlib.colors as colours
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import _pickle as cPickle
from skimage.transform import rescale

import scipy.ndimage as ndimage
from scipy.interpolate import griddata
from scipy.interpolate import RegularGridInterpolator

import numpy as np


# Longitude phase offset in degrees (-125).
# NOTE(review): only referenced by a commented-out target_points line further
# down, so it is currently unused.
phaseshift=-90-35


# =============================================================================
# In[0] Set up a color table
# =============================================================================

# 220 colours from the reversed gist_rainbow map (from 0.25 to 1) and 32 from
# the dark end (0 to 0.23) of gnuplot2.
colors2 = plt.cm.gist_rainbow_r(np.linspace(0.25, 1, 220))
colors1 = plt.cm.gnuplot2(np.linspace(0, 0.23, 32))

# combine them and build a new colormap (252 RGBA rows, gnuplot2 part first)
colors = np.vstack((colors1, colors2))

# Rows 125-130 (about the middle of the 252-row table) are painted
# semi-transparent grey to give a neutral band in the map.
colors[125:131]=[0.5,0.5,0.5,0.8]
mymap = colours.LinearSegmentedColormap.from_list('my_colormap', colors)


# 256 evenly spaced RGBA samples of the custom map (presumably for picking
# discrete plot colours later in the file).
plotcolors= mymap(np.linspace(0, 1, 256))


# =============================================================================
# =============================================================================
# %%   Load all the vectors from Chowdhury et al., 2023
# =============================================================================
# =============================================================================


"""
Bin 4 Gaussian fits
"""
filedir = '/Users/jcfq2/OneDrive - Northumbria University - Production Azure AD/python/saturnproject/Gaussian_fits/magnetic_phase/10-pixel_smooth/'

filename1='northern_bin1_315_45_order19.pickle'
filename2='northern_bin2_45_135_order19.pickle'
filename3='northern_bin3_135_225_order19.pickle'
filename4='northern_bin4_225_315_order19.pickle'


filepath = filedir+filename1
with open(filepath, 'rb') as opened_file:
    bin0 = cPickle.load(opened_file)

filepath = filedir+filename2
with open(filepath, 'rb') as opened_file:
    bin90 = cPickle.load(opened_file)

filepath = filedir+filename3
with open(filepath, 'rb') as opened_file:
    bin180 = cPickle.load(opened_file)

filepath = filedir+filename4
with open(filepath, 'rb') as opened_file:
    bin270 = cPickle.load(opened_file)


R=75000
c=300000


int0 = bin0[0][:,:,0]
pos0 = bin0[1][:,:,0]
int90 = bin90[0][:,:,0]
pos90 = bin90[1][:,:,0]
int180 = bin180[0][:,:,0]
pos180 = bin180[1][:,:,0]
int270 = bin270[0][:,:,0]
pos270 = bin270[1][:,:,0]

int_err = bin0[0][:,:,1]
pos_err = bin0[1][:,:,1]

vel0=pos0*1/R*(-c)
vel90=pos90*1/R*(-c)
vel180=pos180*1/R*(-c)
vel270=pos270*1/R*(-c)


vel0_180=vel0-vel180
vel90_270=vel90-vel270

vel0_180[vel0_180>4]=0
vel0_180[vel0_180<-4]=0
vel90_270[vel90_270>4]=0
vel90_270[vel90_270<-4]=0



rs_vel0_180 = rescale(vel0_180, (0.74486,1), anti_aliasing=True)
rs_vel90_270 = rescale(vel90_270, (0.74486,1), anti_aliasing=True)


s_vel0_180 = rs_vel0_180[1:,100-44:100+44]
s_vel90_270 = rs_vel90_270[1:,100-44:100+44]

phase=360-45

def cart2pol(x,y):
    """Convert Cartesian coordinates to polar coordinates.

    Works element-wise on scalars or NumPy arrays.

    Returns
    -------
    (rho, phi) : radius and angle, with phi = arctan2(y, x) in radians.
    """
    return (np.sqrt(x ** 2 + y ** 2), np.arctan2(y, x))

def pol2cart(rho,phi):
    """Convert polar coordinates (rho, phi in radians) to Cartesian (x, y).

    Works element-wise on scalars or NumPy arrays; the inverse of cart2pol.
    """
    return (rho * np.cos(phi), rho * np.sin(phi))


def phase_shift(vel0_180,vel90_270,phase):
    """Rotate the 2-D vector field (vel90_270, vel0_180) by ``phase`` degrees.

    vel90_270 is treated as the x component and vel0_180 as the y component.
    The vector angle is increased by ``phase`` degrees while its magnitude is
    kept unchanged.

    Returns
    -------
    (new_vel0_180, new_vel90_270) : the rotated y and x components, in that
    order (same order as the first two arguments).
    """
    magnitude = np.sqrt(vel90_270 ** 2 + vel0_180 ** 2)
    rotated_angle = np.arctan2(vel0_180, vel90_270) + np.radians(phase)
    return (magnitude * np.sin(rotated_angle), magnitude * np.cos(rotated_angle))


# Rotate both difference maps by -phase degrees.  The fill outside the input is
# constant (zero).  ndimage.rotate enlarges the output by default
# (reshape=True), so the result is cropped back below.
rvel0_180 = ndimage.rotate(s_vel0_180, -phase, mode = 'constant')
rvel90_270 = ndimage.rotate(s_vel90_270, -phase, mode = 'constant')


# Centre of the rotated image: half the number of columns (x) and rows (y).
coi_x = rvel0_180[0,:].size/2
coi_y = rvel0_180[:,0].size/2


# Crop an 88x88 window around that centre.
rrvel0_180=rvel0_180[int(coi_y)-44:int(coi_y)+44,int(coi_x)-44:int(coi_x)+44]
rrvel90_270=rvel90_270[int(coi_y)-44:int(coi_y)+44,int(coi_x)-44:int(coi_x)+44]



# Rotate the vector components themselves by +phase degrees.
# NOTE(review): only the first output gets a new name (rrrvel0_180).  The
# second output overwrites rrvel90_270 in place, so after this line
# rrvel90_270 holds the rotated component.
rrrvel0_180, rrvel90_270 = phase_shift(rrvel0_180,rrvel90_270,phase)


# Pixel coordinate axes for the 88x88 window.
# NOTE(review): linspace(0, 89, 88) gives a spacing of 89/87 (~1.023), not 1;
# np.arange(88) may have been intended.  Confirm before relying on exact
# positions.
x=np.linspace(0,89,88)
y=np.linspace(0,89,88)
xv, yv = np.meshgrid(x, y)


# Map pixel positions to plot units: x spans -4.4..4.5 and y spans
# -1.1..1.125 (different scale factors for the two axes).
xv = (xv-44)/44*4.4
yv = (yv-44)/44*1.1

# Average every 4x4 pixel tile (88 / 4 = 22 tiles per side), ignoring NaNs, to
# get a sparser 22x22 grid of positions and vectors.  xxx indexes tile rows
# and yyy tile columns.
xvv=np.zeros((22,22)) 
yvv=np.zeros((22,22)) 
velx=np.zeros((22,22)) 
vely=np.zeros((22,22)) 
for xxx in range(22):
    for yyy in range(22):
        xvv[xxx,yyy] = np.nanmean(xv[xxx*4:xxx*4+4,yyy*4:yyy*4+4])
        yvv[xxx,yyy] = np.nanmean(yv[xxx*4:xxx*4+4,yyy*4:yyy*4+4])
        velx[xxx,yyy] = np.nanmean(rrrvel0_180[xxx*4:xxx*4+4,yyy*4:yyy*4+4])
        vely[xxx,yyy] = np.nanmean(rrvel90_270[xxx*4:xxx*4+4,yyy*4:yyy*4+4])

# Blank the first 3 and last 2 tile columns of the x positions, presumably so
# no vectors are drawn at the edges of the rotated/cropped window.
xvv[:,0:3]=np.nan
xvv[:,-2:]=np.nan

# Two reference points: radius 2.5 at 269 deg and radius 4.0 at 200 deg.  The
# x coordinate is compressed by 1.1/4.4 to match the aspect of the xv/yv
# scaling above.  Presumably used for annotations later in the file.
yposcirc2 = 2.5*np.sin(np.radians(89+180))    
xposcirc2 = 2.5*np.cos(np.radians(89+180))    * 1.1 / 4.4
yposcirc1 = 4.0*np.sin(np.radians(20+180))    
xposcirc1 = 4.0*np.cos(np.radians(20+180))    * 1.1 / 4.4

# Arrow properties for matplotlib annotations (black filled arrow head).
prop = dict(arrowstyle="-|>,head_width=0.2,head_length=0.4",
        shrinkA=0,shrinkB=0,facecolor='black', edgecolor='black')  






# =============================================================================
# =============================================================================
# %%  Load the data and vectors from Smith 2011 - using conversions from Pcurrent and Pw
# =============================================================================
# =============================================================================




def rotate4(input):
    """Swap the first two axes of ``input`` (a plain transpose for 2-D data).

    This is exactly what ``np.fliplr(np.rot90(input, 3))`` does.  A clockwise
    quarter turn followed by a left-right flip is a transpose.  In this script
    it converts the model grids from (latitude, longitude) layout to
    (longitude, latitude) layout.
    """
    return np.swapaxes(input, 0, 1)



# Smith et al 2011

# if keyword_set(arrowcol) then arrowc = arrowcol else 
arrowcol = 255
# Model grid dimensions: nn vertical levels (presumably altitude/pressure
# levels - TODO confirm), nm = 91 latitudes (2-degree spacing, see lat1d
# below), nl = 36 longitudes (10-degree spacing).  nt is not used here.
nn = 20
nm = 91
nl = 36
nt = 1
# Level to extract: the last one (index nn-1).
nsel = nn-1
# nsel=8

# ; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# ; BLOCK TO LOAD CURRENTS 
# ; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

# Number of quantities stored per grid point in the currents file.
no_sets =12
filename = '/Users/jcfq2/OneDrive - Northumbria University - Production Azure AD/python/jwst/saturn/smith2010-data/connie.060currents'


print('Loading.... ',filename)



# Not used below (left over from the IDL original).
nnns = nn*no_sets

# One header line, then one row of nl values per (latitude m, level n,
# quantity k), with k varying fastest and m slowest (matching the loop below).
connie_data=np.loadtxt(filename,skiprows = 1)


data = np.zeros((no_sets,nn,nm,nl))

count = 0
for m in range(nm):
   for n in range(nn):
      for k in range(no_sets):
         
         data[k,n,m,:] = connie_data[count,:]
         count = count + 1
         

# Current components on a (nm, nl+1) grid.  The extra longitude column is a
# copy of column 0, which closes the 0/360-degree wrap.
jx = np.zeros((nm,nl+1))
jy = np.zeros((nm,nl+1))
jz = np.zeros((nm,nl+1))

# jx[:,0:nl] = data[5,nsel,:,0:nl]
# jy[:,0:nl] = data[6,nsel,:,0:nl]
# jz[:,0:nl] = data[9,nsel,:,0:nl]

# Quantity 0 at the selected level (the name suggests a Pedersen conductivity
# - TODO confirm); not used in this section.
pc = data[0,nsel,:,0:nl]

# Quantities 7, 8 and 10 are taken as the x, y and z current components.
jx[:,0:nl] = data[7,nsel,:,0:nl]
jy[:,0:nl] = data[8,nsel,:,0:nl]
jz[:,0:nl] = data[10,nsel,:,0:nl]

jx[:,nl] = jx[:,0]
jy[:,nl] = jy[:,0]
jz[:,nl] = jz[:,0]

# Transpose to (nl+1, nm): rows = longitude, columns = latitude.  The factor
# 1e9 is presumably a unit conversion (e.g. to nA/m^2) - TODO confirm.
jx = rotate4(jx)*1e9
jy = rotate4(jy)*1e9
jz = rotate4(jz)*1e9

# Remove the zonal (longitudinal) mean at each latitude.  The mean is taken
# over the nl unique longitude rows (row nl duplicates row 0) and has shape
# (1, nm).
jx_mean = np.mean(jx[0:nl,:], axis=0, keepdims=True)
jy_mean = np.mean(jy[0:nl,:], axis=0, keepdims=True)
jz_mean = np.mean(jz[0:nl,:], axis=0, keepdims=True)

# Subtract the mean from every one of the nl+1 longitude rows (np.tile plays
# the role of IDL's replicate(1,nl+1)).
jxper = jx - np.tile(jx_mean, (nl + 1,1))
jyper = jy - np.tile(jy_mean, (nl + 1,1))
jzper = jz - np.tile(jz_mean, (nl + 1,1))



# Longitudes 0..360 in 10-degree steps (nl+1 values) and latitudes -90..90 in
# 2-degree steps (nm values).
lon1d = 360 * np.arange(nl + 1) / float(nl)
lat1d = 180 * (np.arange(nm) / float(nm - 1) - 0.5)

# 2-D coordinate grids: built as (nm, nl+1), then transposed by rotate4 to
# (nl+1, nm) to match jx/jy/jz.
lon2d = rotate4(np.tile(lon1d, (nm, 1)) )   # shape (nl+1, nm) after rotate4
lat2d = rotate4(np.tile(lat1d[:, np.newaxis], (1, nl + 1)) ) # shape (nl+1, nm) after rotate4





# ; we want to take the northern data and map the coordinates to
# ; an x-y grid as viewed from above. then try contour plotting this

# ; first shot - scale equal latitude spacing as equal radial spacing
# ; there is probably a name for this sort of projection which i
# ; could look up if i could be arsed!


# Colatitude (degrees) and longitude (radians) grids, both shaped (nl+1, nm)
# like jxper/jyper.
th2d = 90-lat2d
ph2drad = lon2d*np.pi/180.0
     

# "First shot" polar positions: radius proportional to colatitude.
x = th2d*np.cos(ph2drad)
y = th2d*np.sin(ph2drad)



# Rotate the (jxper, jyper) component pair by the longitude angle.  This
# presumably takes local (meridional, azimuthal) components into the fixed x/y
# frame of a view from above the pole.
jxx =  jxper*np.cos(ph2drad) - jyper*np.sin(ph2drad)
jyy =  jxper*np.sin(ph2drad) + jyper*np.cos(ph2drad)
jxxyytot = np.sqrt(jxx**2 + jyy**2)

# Unit vectors.  With `where=` and no `out=`, NumPy allocates an uninitialised
# output, so elements where the mask is False (zero magnitude) would hold
# arbitrary memory.  A zero-filled `out` makes zero-length vectors (0, 0).
with np.errstate(divide='ignore', invalid='ignore'):
    jxxsc = np.divide(jxx, jxxyytot, out=np.zeros_like(jxx), where=jxxyytot!=0)
    jyysc = np.divide(jyy, jxxyytot, out=np.zeros_like(jyy), where=jxxyytot!=0)



# ; second shot - scale radial dimension and latitudinal component of arrows
# ; as if we were looking at the planet from above. To do this we need to
# ; scale the meridional component down by a factor of costheta

# ; - - - - - - this block of code scales arrows to same length and then
# ; applies scaling for orthographic projection. Arrows close to equator
# ; will look 'squashed' producing impression of spherical geometry

# Lowest latitude (degrees) of the polar view; sets the grid radius maxrad.
minlat = 0.0


# Northern hemisphere only: latitude columns 45..90 of the (nl+1, nm) grids,
# i.e. latitudes 0..90 degrees.
jxn = jxper[:,45:]
jyn = jyper[:,45:]
lon2dN = lon2d[:,45:]
lat2dN = lat2d[:,45:]

jtotn = np.sqrt(jxn**2 + jyn**2)

# Normalise by the largest northern current magnitude.  The first component
# is also multiplied by cos(colatitude), which shortens it toward the equator
# as in an orthographic view from above the pole.
jsc = np.max(jtotn)
print('MAXIMUM total j for scaling = ',jsc)
jxnsc = jxn*np.cos(th2d[:,45:]*np.pi/180)/jsc
jynsc = jyn/jsc



# Rotate the scaled components by longitude into the polar-view x/y frame.
# This overwrites the full-globe jxx/jyy with (nl+1, 46) northern arrays.
jxx =  jxnsc*np.cos(ph2drad[:,45:]) - jynsc*np.sin(ph2drad[:,45:])
jyy =  jxnsc*np.sin(ph2drad[:,45:]) + jynsc*np.cos(ph2drad[:,45:])

# Orthographic radius sin(colatitude) and the matching x/y positions
# (overwrites the earlier x/y).
rad2d = np.sin(np.pi*th2d/180)
x = rad2d*np.cos(ph2drad)
y = rad2d*np.sin(ph2drad)
# NOTE(review): fac is not referenced in the visible code.
fac = 0.03


# ; third shot - interpolate the scaled arrows from last bit (jxx, jyy) onto a cartesian
# ; grid to provide even spacing...

# Regular (2*ngr+1) x (2*ngr+1) = 43 x 43 Cartesian grid spanning +-maxrad
# (maxrad = 1 for minlat = 0).
ngr = 21
maxrad = np.sin(np.pi*(90-minlat)/180.0)
# IDL-original grid construction, superseded by the meshgrid version below:
# xgr1d = maxrad*(np.arange(2*ngr+1)-ngr)/float(ngr)
# ygr1d = maxrad*(np.arange(2*ngr+1)-ngr)/float(ngr)
# xgr = xgr1d#replicate(1,2*ngr+1)
# ygr = replicate(1,2*ngr+1)#ygr1d

# Create 1D grids from -ngr to ngr, normalized by ngr, scaled by maxrad
xgr1d = maxrad * (np.arange(2*ngr + 1) - ngr) / float(ngr)
ygr1d = maxrad * (np.arange(2*ngr + 1) - ngr) / float(ngr)

# Create 2D grid arrays for x and y (meshgrid style)
xgr, ygr = np.meshgrid(xgr1d, ygr1d)

# Convert xgr, ygr to polar coords (phi, colatitude)
phgr = np.arctan2(ygr, xgr)  # atan2 gives angle between -pi and pi
# Shift negative angles to [0, 2pi]
phgr[phgr < 0] += 2 * np.pi
# Fractional longitude index (0..nl); not referenced in the visible code
phgrind = nl * phgr / (2 * np.pi) 
# Compute colatitude in degrees.  Grid points with radius > 1 (the corners of
# the square grid) give arcsin of a value > 1, i.e. NaN with a RuntimeWarning;
# they presumably end up NaN after interpolation - confirm.
colatgr = np.arcsin(np.sqrt(xgr**2 + ygr**2)) * 180 / np.pi

latgr = 90.0 - colatgr
longgr = 360 * phgr / (2 * np.pi) 




# Step 2: Flatten the target lat/lon positions into (N, 2)
# target_points = np.column_stack((latgr.ravel(), (longgr.ravel()-phaseshift) % 360   ))  # shape (43*43, 2)
target_points = np.column_stack((latgr.ravel(), (longgr.ravel() )))  # shape (43*43, 2)
 
# Step 3: Flatten the source grid and values.  Interpolation is linear in
# (lat, lon) space; the duplicated 360-degree longitude row covers the wrap.
source_points = np.column_stack((lat2dN.ravel(), lon2dN.ravel()))


source_values = jxx.ravel()
# Step 4: Interpolate
jx_interp_flat = griddata(source_points, source_values, target_points, method='linear')
# Step 5: Reshape interpolated values back to 2D
jxxgr = jx_interp_flat.reshape(latgr.shape)  # shape (43, 43)



source_values = jyy.ravel()
# Step 4: Interpolate
jy_interp_flat = griddata(source_points, source_values, target_points, method='linear')
# Step 5: Reshape interpolated values back to 2D
jyygr = jy_interp_flat.reshape(latgr.shape)  # shape (43, 43)





# Quick-look images of the interpolated components.
plt.imshow(jxxgr)
plt.show()

plt.imshow(jyygr)
plt.show()





# if keyword_set(arrowcol) then arrowc = arrowcol else 
arrowcol = 255
# Grid dimensions of the temperature file (same values as the currents file:
# t_nn levels, t_nm = 91 latitudes, t_nl = 36 longitudes).
t_nn = 20
t_nm = 91
t_nl = 36
# t_nt = 1
# Level index to extract (a later section repeats this for t_nsel = 7).
t_nsel = 5

# nsel=8

# ; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# ; BLOCK TO LOAD TEMPERATURE AND ALTITUDE
# ; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

# Two quantities per grid point: 0 = temperature, 1 = altitude (going by how
# they are assigned below).
t_no_sets = 2
t_filename = '/Users/jcfq2/OneDrive - Northumbria University - Production Azure AD/python/jwst/saturn/smith2010-data/connie.060temp'


print('Loading.... ',t_filename)


t_connie_data=np.loadtxt(t_filename,skiprows = 1)


# Rows are ordered with quantity k fastest, then level n, then latitude m.
t_data = np.zeros((t_no_sets,t_nn,t_nm,t_nl))

count = 0
for m in range(t_nm):
   for n in range(t_nn):
      for k in range(t_no_sets):
         
         t_data[k,n,m,:] = t_connie_data[count,:]
         count = count + 1
         
print(t_filename,' Loaded')


# (t_nm, t_nl+1) grids; the extra column copies column 0 to close the wrap.
temp = np.zeros((t_nm,t_nl+1))
alt = np.zeros((t_nm,t_nl+1))

temp[:,0:t_nl] = t_data[0,t_nsel,:,0:t_nl]
alt[:,0:t_nl] = t_data[1,t_nsel,:,0:t_nl]


temp[:,t_nl] = temp[:,0]
alt[:,t_nl] = alt[:,0]

# Transpose to (longitude, latitude) layout, like the current arrays.
temp = rotate4(temp)
alt = rotate4(alt)

# Diagnostic: selected level and the median altitude over the first six
# latitude columns (latitudes -90 to -80 degrees).
print(t_nsel,np.nanmedian((alt[:,:6])))


# ; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# ; BLOCK TO LOAD WINDS 
# ; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -


# Grid dimensions of the wind file; winds are taken at the same level as the
# temperature.
w_nn = 20
w_nm = 91
w_nl = 36
w_nsel=t_nsel
# t_nt = 1

# nsel=8

# ; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# ; BLOCK TO LOAD WINDS
# ; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

# Four quantities per grid point, loaded below as vx, vy, vz and vd.
w_no_sets = 4
w_filename = '/Users/jcfq2/OneDrive - Northumbria University - Production Azure AD/python/jwst/saturn/smith2010-data/connie.060wind'


print('Loading.... ',w_filename)


w_connie_data=np.loadtxt(w_filename,skiprows = 1)


# Rows are ordered with quantity k fastest, then level n, then latitude m.
w_data = np.zeros((w_no_sets,w_nn,w_nm,w_nl))

count = 0
for m in range(w_nm):
   for n in range(w_nn):
      for k in range(w_no_sets):
         
         w_data[k,n,m,:] = w_connie_data[count,:]
         count = count + 1
         
print(w_filename,' Loaded')




# (w_nm, w_nl+1) grids; the extra column copies column 0 to close the wrap.
vx = np.zeros((w_nm,w_nl+1))
vy = np.zeros((w_nm,w_nl+1))
vz = np.zeros((w_nm,w_nl+1))
vd = np.zeros((w_nm,w_nl+1))

vx[:,0:w_nl] = w_data[0,w_nsel,:,0:w_nl]
vy[:,0:w_nl] = w_data[1,w_nsel,:,0:w_nl]
vz[:,0:w_nl] = w_data[2,w_nsel,:,0:w_nl]
vd[:,0:w_nl] = w_data[3,w_nsel,:,0:w_nl]


vx[:,w_nl] = vx[:,0]
vy[:,w_nl] = vy[:,0]
vz[:,w_nl] = vz[:,0]
vd[:,w_nl] = vd[:,0]

# Transpose to (longitude, latitude) layout.
vx = rotate4(vx)
vy = rotate4(vy)
vz = rotate4(vz)
vd = rotate4(vd)



# Remove the zonal (longitudinal) mean at each latitude.  The mean is taken
# over the w_nl unique longitude rows (row w_nl duplicates row 0).
vx_mean = np.mean(vx[0:w_nl,:], axis=0, keepdims=True)
vy_mean = np.mean(vy[0:w_nl,:], axis=0, keepdims=True)
vz_mean = np.mean(vz[0:w_nl,:], axis=0, keepdims=True)

# Subtract the mean from all w_nl+1 rows (np.tile plays the role of IDL's
# replicate(1,w_nl+1)).
vxper = vx - np.tile(vx_mean, (w_nl + 1,1))
vyper = vy - np.tile(vy_mean, (w_nl + 1,1))
vzper = vz - np.tile(vz_mean, (w_nl + 1,1))


# Same zonal-mean removal for temperature (w_nl equals t_nl here).
temp_mean = np.mean(temp[0:w_nl,:], axis=0, keepdims=True)
tempper = temp - np.tile(temp_mean, (w_nl + 1,1))



# tempper = temp - replicate(1,nl+1)#(total(temp[0:nl-1,*],1)/float(nl))


# Longitude/latitude grids for this file.  They are identical to the currents
# grids because t_nl == nl and t_nm == nm.
lon1d = 360 * np.arange(t_nl + 1) / float(t_nl)
lat1d = 180 * (np.arange(t_nm) / float(t_nm - 1) - 0.5)

# Built as (t_nm, t_nl+1), then transposed by rotate4 to (t_nl+1, t_nm).
lon2d = rotate4(np.tile(lon1d, (t_nm, 1)) )   # shape (t_nl+1, t_nm) after rotate4
lat2d = rotate4(np.tile(lat1d[:, np.newaxis], (1, t_nl + 1)) ) # shape (t_nl+1, t_nm) after rotate4

# Integer index grids converted to degrees: longitude 0..360 in 10-degree
# steps, latitude 0..180 in 2-degree steps (offset by +90 from lat2d).  Not
# referenced in the visible code.
lon_idx, lat_idx  = np.meshgrid(np.arange(vxper.shape[0]), np.arange(vxper.shape[1]))

lat_idx = rotate4(lat_idx *2)
lon_idx = rotate4(lon_idx*10)


# plt.imshow(vxper)
# plt.show()

# plt.imshow(vxper)
# plt.show()



# lon1d = 360 * np.arange(t_nl + 1) / float(t_nl)
# lat1d = 180 * (np.arange(t_nm) / float(t_nm - 1) - 0.5)

# # Create 2D versions (like meshgrid-style)
# lon2d = rotate4(np.tile(lon1d, (t_nm, 1)) )   # shape (t_nm, t_nl+1)
# lat2d = rotate4(np.tile(lat1d[:, np.newaxis], (1, t_nl + 1)) ) # shape (t_nm, t_nl+1)




# # ; we want to take the northern data and map the coordinates to
# # ; an x-y grid as viewed from above. then try contour plotting this

# # ; first shot - scale equal latitude spacing as equal radial spacing
# # ; there is probably a name for this sort of projection which i
# # ; could look up if i could be arsed!


# th2d = 90-lat2d
# ph2drad = lon2d*np.pi/180.0
     

# x = th2d*np.cos(ph2drad)
# y = th2d*np.sin(ph2drad)



# Rotate the wind perturbation components by the longitude angle, as done for
# the currents.  ph2drad is the (nl+1, nm) longitude grid in radians built in
# the currents section, the same layout as vxper/vyper.
vxx =  vxper*np.cos(ph2drad) - vyper*np.sin(ph2drad)
vyy =  vxper*np.sin(ph2drad) + vyper*np.cos(ph2drad)
vxxyytot = np.sqrt(vxx**2 + vyy**2)

# Unit vectors.  `where=` without `out=` would leave masked (zero-magnitude)
# elements uninitialised, so pass a zero-filled output array.
with np.errstate(divide='ignore', invalid='ignore'):
    vxxsc = np.divide(vxx, vxxyytot, out=np.zeros_like(vxx), where=vxxyytot!=0)
    vyysc = np.divide(vyy, vxxyytot, out=np.zeros_like(vyy), where=vxxyytot!=0)




# ; second shot - scale radial dimension and latitudinal component of arrows
# ; as if we were looking at the planet from above. To do this we need to
# ; scale the meridional component down by a factor of costheta

# ; - - - - - - this block of code scales arrows to same length and then
# ; applies scaling for orthographic projection. Arrows close to equator
# ; will look 'squashed' producing impression of spherical geometry

# NOTE(review): unused.  It appears to be a search-and-replace artefact of
# `minlat` (nl -> t_nl).
mit_nlat = 0.0


# Northern hemisphere only (latitude columns 45..90, i.e. 0..90 degrees).
vxn = vxper[:,45:]
vyn = vyper[:,45:]
lon2dN = lon2d[:,45:]
lat2dN = lat2d[:,45:]

jtotn = np.sqrt(vxn**2 + vyn**2)

# Normalise by the largest northern wind perturbation; the first component is
# scaled by cos(colatitude) for the orthographic view.  th2d is the colatitude
# grid from the currents section.  The message below says "j", but here it
# reports the wind maximum.
jsc = np.max(jtotn)
print('MAXIMUM total j for scaling = ',jsc)
vxnsc = vxn*np.cos(th2d[:,45:]*np.pi/180)/jsc
vynsc = vyn/jsc



# Rotate the scaled components by longitude into the polar-view x/y frame.
vxx =  vxnsc*np.cos(ph2drad[:,45:]) - vynsc*np.sin(ph2drad[:,45:])
vyy =  vxnsc*np.sin(ph2drad[:,45:]) + vynsc*np.cos(ph2drad[:,45:])




# Interpolate onto the 43 x 43 Cartesian grid.  This reuses source_points and
# target_points from the currents section, which is valid because the lat/lon
# grids are identical.
source_values = vxx.ravel()
# Step 4: Interpolate
vx_interp_flat = griddata(source_points, source_values, target_points, method='linear')
# Step 5: Reshape interpolated values back to 2D
vxxgr = vx_interp_flat.reshape(latgr.shape)  # shape (43, 43)



source_values = vyy.ravel()
# Step 4: Interpolate
vy_interp_flat = griddata(source_points, source_values, target_points, method='linear')
# Step 5: Reshape interpolated values back to 2D
vyygr = vy_interp_flat.reshape(latgr.shape)  # shape (43, 43)


















# print (np.max(vxper),np.max(vyper))
# print (np.min(vxper),np.min(vyper))

# =============================================================================
# %% redo the above for temperature and winds for the higher level
# =============================================================================



# if keyword_set(arrowcol) then arrowc = arrowcol else 
arrowcol = 255
# Grid dimensions of the temperature file (as in the level-5 section).
t_nn = 20
t_nm = 91
t_nl = 36
# t_nt = 1
# Higher level to extract.  temp, alt, vx, ... are overwritten here; the
# level-7 perturbations are stored under new names (vxper2, tempper2, ...).
t_nsel = 7

# nsel=8

# ; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# ; BLOCK TO LOAD TEMPERATURE AND ALTITUDE
# ; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

# Two quantities per grid point: 0 = temperature, 1 = altitude.
t_no_sets = 2
t_filename = '/Users/jcfq2/OneDrive - Northumbria University - Production Azure AD/python/jwst/saturn/smith2010-data/connie.060temp'


print('Loading.... ',t_filename)


# index = np.arange(t_no_sets)

# t_nnns = t_nn*t_no_sets


t_connie_data=np.loadtxt(t_filename,skiprows = 1)


# Rows are ordered with quantity k fastest, then level n, then latitude m.
t_data = np.zeros((t_no_sets,t_nn,t_nm,t_nl))

count = 0
for m in range(t_nm):
   for n in range(t_nn):
      for k in range(t_no_sets):
         
         t_data[k,n,m,:] = t_connie_data[count,:]
         count = count + 1
         
print(t_filename,' Loaded')


# (t_nm, t_nl+1) grids; the extra column copies column 0 to close the wrap.
temp = np.zeros((t_nm,t_nl+1))
alt = np.zeros((t_nm,t_nl+1))

temp[:,0:t_nl] = t_data[0,t_nsel,:,0:t_nl]
alt[:,0:t_nl] = t_data[1,t_nsel,:,0:t_nl]


temp[:,t_nl] = temp[:,0]
alt[:,t_nl] = alt[:,0]

# Transpose to (longitude, latitude) layout.
temp = rotate4(temp)
alt = rotate4(alt)

# Diagnostic: level and median altitude over latitudes -90 to -80 degrees.
print(t_nsel,np.nanmedian((alt[:,:6])))

# ; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# ; BLOCK TO LOAD WINDS 
# ; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -


# Winds at the same level as the temperature.
w_nn = 20
w_nm = 91
w_nl = 36
w_nsel=t_nsel
# t_nt = 1

# nsel=8

# Four quantities per grid point, loaded below as vx, vy, vz and vd.
w_no_sets = 4
w_filename = '/Users/jcfq2/OneDrive - Northumbria University - Production Azure AD/python/jwst/saturn/smith2010-data/connie.060wind'


print('Loading.... ',w_filename)


# index = np.arange(w_no_sets)

# w_nnns = w_nn*w_no_sets


w_connie_data=np.loadtxt(w_filename,skiprows = 1)


# Rows are ordered with quantity k fastest, then level n, then latitude m.
w_data = np.zeros((w_no_sets,w_nn,w_nm,w_nl))

count = 0
for m in range(w_nm):
   for n in range(w_nn):
      for k in range(w_no_sets):
         
         w_data[k,n,m,:] = w_connie_data[count,:]
         count = count + 1
         
print(w_filename,' Loaded')




# (w_nm, w_nl+1) grids; the extra column copies column 0 to close the wrap.
vx = np.zeros((w_nm,w_nl+1))
vy = np.zeros((w_nm,w_nl+1))
vz = np.zeros((w_nm,w_nl+1))
vd = np.zeros((w_nm,w_nl+1))

vx[:,0:w_nl] = w_data[0,w_nsel,:,0:w_nl]
vy[:,0:w_nl] = w_data[1,w_nsel,:,0:w_nl]
vz[:,0:w_nl] = w_data[2,w_nsel,:,0:w_nl]
vd[:,0:w_nl] = w_data[3,w_nsel,:,0:w_nl]


vx[:,w_nl] = vx[:,0]
vy[:,w_nl] = vy[:,0]
vz[:,w_nl] = vz[:,0]
vd[:,w_nl] = vd[:,0]

# Transpose to (longitude, latitude) layout.
vx = rotate4(vx)
vy = rotate4(vy)
vz = rotate4(vz)
vd = rotate4(vd)



# Remove the zonal (longitudinal) mean at each latitude.  The mean is taken
# over the w_nl unique longitude rows (row w_nl duplicates row 0).
vx_mean = np.mean(vx[0:w_nl,:], axis=0, keepdims=True)
vy_mean = np.mean(vy[0:w_nl,:], axis=0, keepdims=True)
vz_mean = np.mean(vz[0:w_nl,:], axis=0, keepdims=True)

# Subtract the mean from all w_nl+1 rows (np.tile plays the role of IDL's
# replicate(1,w_nl+1)).
vxper2 = vx - np.tile(vx_mean, (w_nl + 1,1))
vyper2 = vy - np.tile(vy_mean, (w_nl + 1,1))
vzper2 = vz - np.tile(vz_mean, (w_nl + 1,1))


# Same zonal-mean removal for the level-7 temperature (w_nl equals t_nl).
temp_mean = np.mean(temp[0:w_nl,:], axis=0, keepdims=True)
tempper2 = temp - np.tile(temp_mean, (w_nl + 1,1))


# print (np.max(vxper2),np.max(vyper2))
# print (np.min(vxper2),np.min(vyper2))





# Rotate the level-7 wind perturbations by the longitude angle (ph2drad is the
# (nl+1, nm) longitude grid in radians from the currents section).
vxx2 =  vxper2*np.cos(ph2drad) - vyper2*np.sin(ph2drad)
vyy2 =  vxper2*np.sin(ph2drad) + vyper2*np.cos(ph2drad)
vxx2yytot = np.sqrt(vxx2**2 + vyy2**2)

# Unit vectors.  `where=` without `out=` would leave masked (zero-magnitude)
# elements uninitialised, so pass a zero-filled output array.
with np.errstate(divide='ignore', invalid='ignore'):
    vxx2sc = np.divide(vxx2, vxx2yytot, out=np.zeros_like(vxx2), where=vxx2yytot!=0)
    vyy2sc = np.divide(vyy2, vxx2yytot, out=np.zeros_like(vyy2), where=vxx2yytot!=0)



# plt.imshow(vxx2sc)
# plt.show()
# plt.imshow(vyy2sc)
# plt.show()


# ; second shot - scale radial dimension and latitudinal component of arrows
# ; as if we were looking at the planet from above. To do this we need to
# ; scale the meridional component down by a factor of costheta

# ; - - - - - - this block of code scales arrows to same length and then
# ; applies scaling for orthographic projection. Arrows close to equator
# ; will look 'squashed' producing impression of spherical geometry

# NOTE(review): unused.  It appears to be a search-and-replace artefact of
# `minlat`.
mit_nlat = 0.0


# Northern hemisphere only (latitude columns 45..90, i.e. 0..90 degrees).
vxn2 = vxper2[:,45:]
vyn2 = vyper2[:,45:]
lon2dN = lon2d[:,45:]
lat2dN = lat2d[:,45:]

jtotn = np.sqrt(vxn2**2 + vyn2**2)

# Normalise by the largest northern level-7 wind perturbation.  The first
# component is scaled by cos(colatitude) for the orthographic view.  The
# message says "j" but reports the wind maximum.
jsc = np.max(jtotn)
print('MAXIMUM total j for scaling = ',jsc)
vxnsc2 = vxn2*np.cos(th2d[:,45:]*np.pi/180)/jsc
vynsc2 = vyn2/jsc



# Rotate the scaled components by longitude into the polar-view x/y frame.
vxx2 =  vxnsc2*np.cos(ph2drad[:,45:]) - vynsc2*np.sin(ph2drad[:,45:])
vyy2 =  vxnsc2*np.sin(ph2drad[:,45:]) + vynsc2*np.cos(ph2drad[:,45:])




# Interpolate onto the 43 x 43 Cartesian grid, reusing source_points and
# target_points from the currents section (identical lat/lon grids).
source_values = vxx2.ravel()
# Step 4: Interpolate
vx_interp_flat = griddata(source_points, source_values, target_points, method='linear')
# Step 5: Reshape interpolated values back to 2D
vxx2gr = vx_interp_flat.reshape(latgr.shape)  # shape (43, 43)



source_values = vyy2.ravel()
# Step 4: Interpolate
vy_interp_flat = griddata(source_points, source_values, target_points, method='linear')
# Step 5: Reshape interpolated values back to 2D
vyy2gr = vy_interp_flat.reshape(latgr.shape)  # shape (43, 43)




# plt.imshow(vxx2gr)
# plt.show()
# plt.imshow(vyy2gr)
# plt.show()
# plt.plot([0,1],[0,1])

# plt.show()


# =============================================================================
# %% Load the JWST fitted data
# =============================================================================




####### NB!!!! #################################################################
# since LOS wasn't working within the fits - we now fit without los and correct values here instead!!
####################################

# Fitted JWST products (relative to jwst_dir, the working directory set at the
# top of the script).
temperature2=np.load('saturn_saves/saturn_temperature_2z.npy')
# pre_temperature2=np.load('saturn_saves/saturn_pretemperature_2.npy')
temperature_error2=np.load('saturn_saves/saturn_temperature_error_2z.npy')
density2=np.load('saturn_saves/saturn_density_2z.npy')
# pre_density2=np.load('saturn_saves/saturn_predensity_2.npy')
density_error2=np.load('saturn_saves/saturn_density_error_2z.npy')
totalE2=np.load('saturn_saves/saturn_totalE_2z.npy')
# intensity2=np.load('saturn_saves/saturn_intensity_2.npy')
# ch4_fun2=np.load('saturn_saves/saturn_ch4fun_2.npy')
ch4_hot2=np.load('saturn_saves/saturn_ch4hot_2z.npy')
# ch4_hot2=np.load('saturn_saves/saturn_ch4hot_2x.npy')
# 



# ####### load the LOS values ##########
naxis_maplos=np.load('saturn_spectra_los_shell_2.npy')
naxis_mapcount=np.load('saturn_spectra_count_2.npy')

# NOTE: nacca is a view into naxis_mapcount and nalla is the same object as
# naxis_maplos, so the slice assignments below also modify those arrays.
nacca=naxis_mapcount[:,:,600]
nalla=naxis_maplos
scale=2


# Axis 0 appears to be longitude, padded by 45*scale bins at each end.  Each
# padding strip is added onto the opposite end of the central range (a
# periodic wrap), and the padding is then trimmed off - TODO confirm layout.
nalla[-90*scale:-45*scale,:]=nalla[-90*scale:-45*scale,:]+nalla[0:45*scale,:]
nacca[-90*scale:-45*scale,:]=nacca[-90*scale:-45*scale,:]+nacca[0:45*scale,:]
nalla[45*scale:90*scale,:]=nalla[45*scale:90*scale,:]+nalla[-45*scale:,:]
nacca[45*scale:90*scale,:]=nacca[45*scale:90*scale,:]+nacca[-45*scale:,:]



nalla=nalla[45*scale:-45*scale,:]
nacca=nacca[45*scale:-45*scale,:]


# Mean line-of-sight factor per bin (summed LOS / count).  0/0 gives NaN,
# which nan_to_num turns into 0.
# NOTE(review): a non-zero LOS sum with a zero count gives +inf, which
# nan_to_num turns into ~1.8e308 rather than 0.  Confirm this cannot occur.
loscorr2=nalla/nacca
loscorr2=np.nan_to_num(loscorr2)


# Clockwise quarter turn to match the orientation of the JWST maps below.
loscorr  =np.rot90(loscorr2,3)
# # loscorr[:,135*scale+1:]=loscorr2[:,0:90]



# 

# Grid points per degree.
scale=2

# Full-globe arrays of shape (360*scale+1, 180*scale+1) = (721, 361).
temperature=np.zeros([360*scale+1,180*scale+1])
# pre_temperature=np.zeros([360*scale+1,180*scale+1])
temperature_error=np.zeros([360*scale+1,180*scale+1])
density=np.zeros([360*scale+1,180*scale+1])
# pre_density=np.zeros([360*scale+1,180*scale+1])
density_error=np.zeros([360*scale+1,180*scale+1])
totalE=np.zeros([360*scale+1,180*scale+1])
# intensity=np.zeros([360*scale+1,180*scale+1])
# ch4_fun=np.zeros([360*scale+1,180*scale+1])
ch4_hot=np.zeros([360*scale+1,180*scale+1])

# Put the first 90 columns of each fitted product into the last 90 columns
# (indices 271..360).  If axis 1 spans latitude -90..90 degrees at 1/scale
# degree spacing, that is roughly 45.5-90 degrees north.  The fitted arrays
# must have 721 rows.
temperature[:,135*scale+1:]=temperature2[:,0:90]
temperature_error[:,135*scale+1:]=temperature_error2[:,0:90]
density[:,135*scale+1:]=density2[:,0:90]
density_error[:,135*scale+1:]=density_error2[:,0:90]
totalE[:,135*scale+1:]=totalE2[:,0:90]
ch4_hot[:,135*scale+1:]=ch4_hot2[:,0:90]

#

# Transpose to (361, 721) and reverse the order along the 721-long axis.
temperature= np.fliplr(temperature.T)
density= np.fliplr(density.T)
temperature_error= np.fliplr(temperature_error.T)
density_error= np.fliplr(density_error.T)
totalE= np.fliplr(totalE.T)
ch4_hot= np.fliplr(ch4_hot.T)
# density= np.fliplr(density.T)
# density= np.fliplr(density.T)


import cartopy.crs as ccrs
# NOTE: this RotatedPole CRS is replaced by NorthPolarStereo before plotting.
crs = ccrs.RotatedPole(globe=ccrs.Globe(flattening=(0.0)))


# Longitudinal anomalies: subtract each row's median (taken along axis 1)
# from every column of that row. Broadcasting the (nrows, 1) median over
# all columns fills every column. The previous per-column loop ran over
# range(360*scale) only, so the last column (index 360*scale) of each
# *_diff array was left at zero.
temperature_median = np.median(temperature, axis=1)
temperature_diff = temperature - temperature_median[:, np.newaxis]


density_median = np.median(density, axis=1)
# The density anomaly carries the line-of-sight correction `loscorr`.
density_diff = (density - density_median[:, np.newaxis]) * loscorr


# LOS-corrected density: corrected anomaly plus the (uncorrected) row median.
density2 = density_diff + density_median[:, np.newaxis]


totalE_median = np.median(totalE, axis=1)
totalE_diff = totalE - totalE_median[:, np.newaxis]


# Keep a second reference to `vyper` outside the plotting cell, so that
# re-running the cell below rebinds `vyper` to the same object.
vyperB = vyper

# =============================================================================
# =============================================================================
# %%  Now plot the figure
# =============================================================================
# =============================================================================

vyper = vyperB

# Spherical (zero-flattening) north-polar stereographic CRS.
crs = ccrs.NorthPolarStereo(globe=ccrs.Globe(flattening=0.0))

fig = plt.figure(figsize=(12, 6), dpi=300)
fig.subplots_adjust(hspace=0.1, wspace=0.25)

# Box style shared by the boxed text annotations on the map panels.
bbox_props1 = {
    "boxstyle": "round,pad=0.15",
    "fc": "whitesmoke",
    "ec": "silver",
    "lw": 2,
}




# =============================================================================
# # ---- obs temperature
# =============================================================================


# Top-left panel: JWST temperature anomaly (`temperature_diff`) on a
# north-polar stereographic map centred on 180 deg longitude, poleward of 68 deg N.
ax = plt.subplot(231,projection=ccrs.NorthPolarStereo(central_longitude=180))

ax.set_extent([0, 360, 68, 90], ccrs.PlateCarree())
tp=ax.imshow(temperature_diff, origin='lower', extent=[-180,180,-90,90], transform=ccrs.PlateCarree(),cmap='seismic',vmin=-50,vmax=50)

# ax.text(360-315,65,'d', transform=ccrs.PlateCarree(), ha="center", va="center", weight='bold',bbox=bbox_props1)

cbar=fig.colorbar(tp,location='right',aspect=3)
# Mathtext labels are raw strings: "\D", "\c" and "\P" are invalid escape
# sequences in ordinary literals (warnings; SyntaxWarning on Python 3.12+).
# The resulting string values are unchanged.
cbar.ax.set_ylabel(r'JWST $\Delta_{long}$ Temperature [K]')
pos = cbar.ax.get_position()
ax1 = cbar.ax
# cmap = colormaps['mymap']
d_temp_pos=ax.get_position()


# set up an overlay to transpose the vectors in normal coordinates on top of our projection (i.e. cheat)
ax_overlay = fig.add_subplot(d_temp_pos, facecolor="none")
# ax_overlay.yaxis.set_label_position("right")
ax_overlay.tick_params(left=False, right=False, labelleft=False, labelright=False,
                bottom=False, labelbottom=False)

ax_overlay.quiver(xvv,yvv,-vely,-velx,scale=20,width=0.01,pivot='mid',color="xkcd:dark green")


gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False,
                  linewidth=1, color='k', alpha=0.5, linestyle='dotted')
gl.ylocator = mticker.FixedLocator([40,50,60,70,80])
# West-longitude labels every 60 deg along 70 deg N; latitude labels at 70 and 80 deg N.
for longit in range (0,355,60): ax.text((360-longit+180) % 360,70,str(longit)+r'$^{\circ}$W', transform=ccrs.PlateCarree(), ha="center", va="center",size=8)
for latit in range (70,82,10): ax.text(360-140,latit,str(latit)+r'$^{\circ}$N', transform=ccrs.PlateCarree(), ha="center", va="center",size=8)

# Psi_N phase labels, and dotted meridians at 45, 135, 225 and 315 deg.
ax.text(360-230,64,r'0$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
ax.text(360-50,63,r'180$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="right", va="center",size=8,c='hotpink')
ax.text(360-329,66.5,r'90$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
# ax.text(360-330,69,'90$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='mediumvioletred')
ax.text(360-144,67.7,r'270$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
for drift in range(4): ax.plot([45+90*drift,45+90*drift],[90,40], transform=ccrs.PlateCarree(),c='hotpink',linestyle='dotted')


# =============================================================================
# # ---- obs density
# =============================================================================


# Bottom-left panel: line-of-sight-corrected JWST density anomaly.
ax = plt.subplot(234,projection=ccrs.NorthPolarStereo(central_longitude=180))

ax.set_extent([0, 360, 68, 90], ccrs.PlateCarree())
# `density_diff` is already multiplied by `loscorr` where it is computed, so
# it is plotted as-is. Multiplying by `loscorr` again here would apply the
# LOS correction twice.
# NOTE(review): if the squared correction was intentional, restore
# `density_diff*loscorr` and re-tune vmin/vmax.
tp=ax.imshow(density_diff, origin='lower', extent=[-180,180,-90,90], transform=ccrs.PlateCarree(),cmap='seismic',vmin=-1e15,vmax=1e15)

# Dashed outline of the 72-78 deg N band between -45 and 135 deg longitude
# (PlateCarree coordinates). xposes/yposes/yposes2 are reused on the
# divergence panel.
xposes=np.arange(19)*10-45
yposes=np.zeros_like(xposes)+72
yposes2=np.zeros_like(xposes)+78

ax.plot(xposes,yposes,c='k',linestyle='dashed', transform=ccrs.PlateCarree())
ax.plot(xposes,yposes2,c='k',linestyle='dashed', transform=ccrs.PlateCarree())
ax.plot([xposes[0],xposes[0]],[72,78],c='k',linestyle='dashed', transform=ccrs.PlateCarree())
ax.plot([xposes[-1],xposes[-1]],[72,78],c='k',linestyle='dashed', transform=ccrs.PlateCarree())


# ax.text(360-315,65,'e', transform=ccrs.PlateCarree(), ha="center", va="center", weight='bold',bbox=bbox_props1)

cbar=fig.colorbar(tp,location='right',aspect=4)
# Raw strings keep the mathtext backslashes literal ("\D", "\c", "\P" are
# invalid escapes in ordinary literals); the string values are unchanged.
cbar.ax.set_ylabel(r'JWST $\Delta_{long}$ Density [/m$^2$]')

pos = cbar.ax.get_position()
ax2 = cbar.ax
# cmap = colormaps['mymap']


gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False,
                  linewidth=1, color='k', alpha=0.5, linestyle='dotted')
gl.ylocator = mticker.FixedLocator([40,50,60,70,80])
for longit in range (0,355,60): ax.text((360-longit+180) % 360,70,str(longit)+r'$^{\circ}$W', transform=ccrs.PlateCarree(), ha="center", va="center",size=8)
for latit in range (70,82,10): ax.text(360-140,latit,str(latit)+r'$^{\circ}$N', transform=ccrs.PlateCarree(), ha="center", va="center",size=8)

# Psi_N phase labels, and dotted meridians at 45, 135, 225 and 315 deg.
ax.text(360-220,64,r'0$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
ax.text(360-50,63,r'180$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="right", va="center",size=8,c='hotpink')
ax.text(360-329,66.5,r'90$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
# ax.text(360-330,69,'90$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='mediumvioletred')
ax.text(360-144,67.7,r'270$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
for drift in range(4): ax.plot([45+90*drift,45+90*drift],[90,40], transform=ccrs.PlateCarree(),c='hotpink',linestyle='dotted')
d_den_pos=ax.get_position()
# ax.text(360-270,69,'~3 nbar, ~1300 km', transform=ccrs.PlateCarree(), ha="center", va="top", weight='bold',bbox=bbox_props1)



# =============================================================================
# # ---- model temperature
# =============================================================================


# phaseshift=0

bbox_props1= dict(boxstyle="round,pad=0.15", fc="whitesmoke", ec="silver", lw=2)

# fig.subplots_adjust(hspace=0,wspace=0.2)


# Top-middle panel: sign-flipped model `tempper`, with the longitudes shifted
# by `phaseshift` degrees, on a polar map centred on 270 deg longitude.
# ax = plt.subplot(121)
ax = plt.subplot(232,projection=ccrs.NorthPolarStereo(central_longitude=270))

ax.set_extent([0, 360, 68, 90], ccrs.PlateCarree())
# ax.set_extent([0, 360, 60, 90])


# tp=ax.imshow(rotate4(vxper), origin='lower', extent=[-180,180,-90,90] ,cmap='jet')
# NOTE: contourf ignores origin/extent when X and Y are supplied, as they are here.
tp=ax.contourf(rotate4(lon_idx-180)+phaseshift,rotate4(lat_idx)-90,rotate4(-tempper), origin='lower', extent=[-180,180,-90,90], transform=ccrs.PlateCarree(),cmap='seismic',levels=15)


cbar=fig.colorbar(tp,location='right',aspect=4)
# Raw string: "\D" is an invalid escape sequence in an ordinary literal.
cbar.ax.set_ylabel(r'Model $\Delta_{long}$ Temperature [K]')

gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False,
                    linewidth=1, color='k', alpha=0.5, linestyle='dotted')
gl.ylocator = mticker.FixedLocator([40,50,60,70,80])
# 



# =============================================================================
# # Set up an interpolation grid over the pole using cartopy
# =============================================================================


# # Saturn mean radius in meters (~58,232 km)
# saturn_radius = 58232000.0  # meters

# # Custom Saturn-compatible North Polar Stereographic projection
# saturn_north_pole = ccrs.Stereographic(central_latitude=90,
#                                        central_longitude=0,
#                                        true_scale_latitude=90,
#                                        globe=ccrs.Globe(semimajor_axis=saturn_radius,
#                                                         semiminor_axis=saturn_radius,
#                                                         ellipse=None))


# # # Define projection for transforming
# proj = ccrs.NorthPolarStereo()
# transformer = ccrs.PlateCarree()

# extent = 0.5 * saturn_radius  # ~5.8 million meters
# nx, ny = 22, 22
# x = np.linspace(-extent, extent, nx)
# y = np.linspace(-extent, extent, ny)
# xgrid, ygrid = np.meshgrid(x, y)

# # Use PlateCarree for Saturn as well (no Earth-specific assumptions)
# saturn_plate_carree = ccrs.PlateCarree(globe=ccrs.Globe(semimajor_axis=saturn_radius,
#                                                         semiminor_axis=saturn_radius,
#                                                         ellipse=None))

# # Transform from projected x/y to lon/lat
# lonlat = saturn_plate_carree.transform_points(saturn_north_pole, xgrid, ygrid)
# lon_grid = lonlat[..., 0]
# lat_grid = lonlat[..., 1]


# # Flatten the 2D lat/lon coordinate arrays and corresponding data values
# points = np.column_stack((lat_idx.ravel()-90, 180-lon_idx.ravel()))  # shape (n_points, 2)

# # Now form the query points: the lat/lon locations where you want interpolated values
# query_points = np.column_stack((lat_grid.ravel(), lon_grid.ravel()))  # shape (n_queries, 2)


# # =============================================================================
# # From here we can input whatever values we want and find the interpolated values
# # =============================================================================

# # Interpolate using 'linear' (can also try 'nearest' or 'cubic')
# values = np.flipud(vyper).ravel()  # shape (n_points,)
# vyper_iv = griddata(points, values, query_points, method='linear')
# vyper_iv = vyper_iv.reshape(lat_grid.shape)

# values = np.flipud(-vxper).ravel()  # shape (n_points,)
# vxper_iv = griddata(points, values, query_points, method='linear')
# vxper_iv = vxper_iv.reshape(lat_grid.shape)

# values = vzper.ravel()  # shape (n_points,)
# vzper_iv = griddata(points, values, query_points, method='linear')
# vzper_iv = vxper_iv.reshape(lat_grid.shape)


# =============================================================================
# Continue plotting
# =============================================================================





# Transparent cartesian axes laid over the model-temperature map. The model
# wind vectors are drawn here in plain (x, y) coordinates spanning
# [-0.5, 0.5] on both axes, instead of through the map projection.
d_temp_pos = ax.get_position()
ax_overlay = fig.add_subplot(d_temp_pos, facecolor="none")
ax_overlay.set_xlim(-0.5, 0.5)
ax_overlay.set_ylim(-0.5, 0.5)
ax_overlay.tick_params(left=False, right=False, labelleft=False,
                       labelright=False, bottom=False, labelbottom=False)

# Rotate the arrow grid about its centre by `phaseshift` degrees, matching
# the `+phaseshift` longitude offset applied to the contour plot.
x_center, y_center = 0, 0
theta = np.radians(phaseshift)
cos_theta = np.cos(theta)
sin_theta = np.sin(theta)

x_translated = xgr - x_center
y_translated = ygr - y_center
x_rotated = x_translated * cos_theta - y_translated * sin_theta
y_rotated = x_translated * sin_theta + y_translated * cos_theta
xgr_rot = x_rotated + x_center
ygr_rot = y_rotated + y_center

# The vector components are directions only, so they are rotated without
# any translation.
vxxgr_rot = vxxgr * cos_theta - vyygr * sin_theta
vyygr_rot = vxxgr * sin_theta + vyygr * cos_theta

ax_overlay.quiver(
    xgr_rot.ravel(), ygr_rot.ravel(),
    vxxgr_rot.ravel(), vyygr_rot.ravel(),
    scale=15, width=0.0075, pivot='mid', color="xkcd:pale green",
)

# ax_overlay.quiver(xgr,ygr,jyygr,-jxxgr,scale=20,width=0.01,pivot='mid',color="xkcd:black")

# xgr, ygr









# the uninterpolated arrows
# ax.quiver(360-lon_idx+phaseshift,lat_idx-90, np.flipud(vyper),np.flipud(-vxper), transform=ccrs.PlateCarree(),scale=3e2,width=0.01,pivot='mid',color='black')

# the interpolated grid positions (for checking)
# ax.plot(lon_grid, lat_grid, 'x',transform=ccrs.PlateCarree())

# the interpolated arrows
# ax.quiver(lon_grid+phaseshift,lat_grid, -vyper_iv,-vxper_iv, transform=ccrs.PlateCarree(),scale=2e2,width=0.009,color='xkcd:pale green',pivot='mid')


# Psi_N phase labels (raw strings: "\c" and "\P" are invalid escapes in
# ordinary literals), dotted meridians at 45, 135, 225 and 315 deg, and the
# pressure/altitude tag for this model level.
ax.text(360-215,66.5,r'90$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
ax.text(360-45,65,r'270$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="right", va="center",size=8,c='hotpink')
ax.text(360-315,66.5,r'180$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
# ax.text(360-330,69,'90$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='mediumvioletred')
ax.text(360-140,63,r'0$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
for drift in range(4): ax.plot([45+90*drift,45+90*drift],[90,40], transform=ccrs.PlateCarree(),c='hotpink',linestyle='dotted')

ax.text(360-270,69,'~8 nbar, ~1100 km', transform=ccrs.PlateCarree(), ha="center", va="top", weight='bold',bbox=bbox_props1)


# =============================================================================
# # ---- model divergence
# =============================================================================

# NOTE(review): subplot(212) occupies the whole bottom half of the figure and
# overlaps the 2x3 slot used by subplot(234). The map's fixed aspect ratio
# presumably shrinks it into the centre of that row - confirm the layout.
ax = plt.subplot(212, projection=ccrs.NorthPolarStereo(central_longitude=270))
ax.set_extent([0, 360, 68, 90], ccrs.PlateCarree())

# Sign-flipped model `jzper` on the same phase-shifted longitude grid as the
# model temperature panels. Colour normalisation is fixed to [-0.3, 0.3].
div_lon = rotate4(lon_idx-180) + phaseshift
div_lat = rotate4(lat_idx) - 90
div_field = rotate4(-jzper)
tp = ax.contourf(div_lon, div_lat, div_field,
                 origin='lower', extent=[-180, 180, -90, 90],
                 transform=ccrs.PlateCarree(), cmap='seismic_r',
                 levels=15, vmin=-0.3, vmax=0.3)

cbar = fig.colorbar(tp, location='right', aspect=4)
cbar.ax.set_ylabel(r'Model $\Delta_{long}$ Divergence [nT/m$^2$]')

gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=1,
                  color='k', alpha=0.5, linestyle='dotted')
gl.ylocator = mticker.FixedLocator([40, 50, 60, 70, 80])

# ax.quiver(360-lon_idx+phaseshift,lat_idx-90, np.flipud(jyper),np.flipud(-jxper), transform=ccrs.PlateCarree(),scale=2e7,width=0.01,pivot='mid',color='black')




# # Interpolate using 'linear' (can also try 'nearest' or 'cubic')
# values = np.flipud(jyper).ravel()  # shape (n_points,)
# jyper_iv = griddata(points, values, query_points, method='linear')
# jyper_iv = jyper_iv.reshape(lat_grid.shape)

# values = np.flipud(-jxper).ravel()  # shape (n_points,)
# jxper_iv = griddata(points, values, query_points, method='linear')
# jxper_iv = jxper_iv.reshape(lat_grid.shape)


# the interpolation grid (for checking)
# ax.plot(lon_grid, lat_grid, 'x',transform=ccrs.PlateCarree())

# the interpolated arrows
# ax.quiver(lon_grid+phaseshift,lat_grid, -jyper_iv,-jxper_iv, transform=ccrs.PlateCarree(),scale=1.5e7,width=0.009,color='xkcd:very pale blue',pivot='mid')



# Transparent cartesian axes laid over the divergence map. The model
# (jxxgr, jyygr) vectors are drawn in the same [-0.5, 0.5] frame as the wind
# overlays, instead of through the map projection.
d_temp_pos = ax.get_position()
ax_overlay = fig.add_subplot(d_temp_pos, facecolor="none")
ax_overlay.set_xlim(-0.5, 0.5)
ax_overlay.set_ylim(-0.5, 0.5)
ax_overlay.tick_params(left=False, right=False, labelleft=False,
                       labelright=False, bottom=False, labelbottom=False)

# Arrow positions reuse `xgr_rot`/`ygr_rot` and the angle `theta` computed
# for the model-temperature overlay. Only the vector components are rotated
# here; they are directions, so no translation is needed.
jxxgr_rot, jyygr_rot = (
    jxxgr * np.cos(theta) - jyygr * np.sin(theta),
    jxxgr * np.sin(theta) + jyygr * np.cos(theta),
)

ax_overlay.quiver(
    xgr_rot.ravel(), ygr_rot.ravel(),
    jxxgr_rot.ravel(), jyygr_rot.ravel(),
    scale=15, width=0.0075, pivot='mid', color="xkcd:very pale blue",
)

# ax_overlay.quiver(xgr,ygr,jyygr,-jxxgr,scale=20,width=0.01,pivot='mid',color="xkcd:black")

# xgr, ygr



# Psi_N phase labels (raw strings: "\c" and "\P" are invalid escapes in
# ordinary literals) and dotted meridians at 45, 135, 225 and 315 deg.
ax.text(360-215,66.5,r'90$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
ax.text(360-45,65,r'270$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="right", va="center",size=8,c='hotpink')
ax.text(360-315,66.5,r'180$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
# ax.text(360-330,69,'90$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='mediumvioletred')
ax.text(360-140,63,r'0$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
for drift in range(4): ax.plot([45+90*drift,45+90*drift],[90,40], transform=ccrs.PlateCarree(),c='hotpink',linestyle='dotted')


# The same dashed 72-78 deg N box as on the observed-density panel, shifted
# by +90 deg of longitude (presumably to match this map's 270 deg centre
# versus 180 deg there).
ax.plot(xposes+90,yposes,c='k',linestyle='dashed', transform=ccrs.PlateCarree())
ax.plot(xposes+90,yposes2,c='k',linestyle='dashed', transform=ccrs.PlateCarree())
ax.plot([xposes[0]+90,xposes[0]+90],[72,78],c='k',linestyle='dashed', transform=ccrs.PlateCarree())
ax.plot([xposes[-1]+90,xposes[-1]+90],[72,78],c='k',linestyle='dashed', transform=ccrs.PlateCarree())



#  now add in the winds and temperatures from a higher altitude


# =============================================================================
# # ---- model temperature pressure level 8
# =============================================================================


# NOTE(review): resetting `phaseshift` here does not update `theta`,
# `xgr_rot` or `ygr_rot`. Those were computed from the earlier value and are
# reused by the vector overlay on this panel. They agree only because this
# is the same value (-90-35).
phaseshift=-90-35


bbox_props1= dict(boxstyle="round,pad=0.15", fc="whitesmoke", ec="silver", lw=2)


# Top-right panel: sign-flipped `tempper2` (the second model level), with the
# longitudes shifted by `phaseshift` degrees.
ax = plt.subplot(233,projection=ccrs.NorthPolarStereo(central_longitude=270))

ax.set_extent([0, 360, 68, 90], ccrs.PlateCarree())


# tp=ax.imshow(rotate4(vxper), origin='lower', extent=[-180,180,-90,90] ,cmap='jet')
tp=ax.contourf(rotate4(lon_idx-180)+phaseshift,rotate4(lat_idx)-90,rotate4(-tempper2), origin='lower', extent=[-180,180,-90,90], transform=ccrs.PlateCarree(),cmap='seismic',levels=15)


cbar=fig.colorbar(tp,location='right',aspect=4)
# Raw string: "\D" is an invalid escape sequence in an ordinary literal.
cbar.ax.set_ylabel(r'Model $\Delta_{long}$ Temperature [K]')

gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False,
                    linewidth=1, color='k', alpha=0.5, linestyle='dotted')
gl.ylocator = mticker.FixedLocator([40,50,60,70,80])

# 

# from scipy.interpolate import griddata


# from scipy.interpolate import RegularGridInterpolator

# # Saturn mean radius in meters (~58,232 km)
# saturn_radius = 58232000.0  # meters

# # Custom Saturn-compatible North Polar Stereographic projection
# saturn_north_pole = ccrs.Stereographic(central_latitude=90,
#                                         central_longitude=0,
#                                         true_scale_latitude=90,
#                                         globe=ccrs.Globe(semimajor_axis=saturn_radius,
#                                                         semiminor_axis=saturn_radius,
#                                                         ellipse=None))

# # =============================================================================
# # # Interpolate using 'linear' (can also try 'nearest' or 'cubic')
# # =============================================================================
# values = np.flipud(vyper2).ravel()  # shape (n_points,)
# vyper_iv2 = griddata(points, values, query_points, method='linear')
# vyper_iv2 = vyper_iv2.reshape(lat_grid.shape)

# values = np.flipud(-vxper2).ravel()  # shape (n_points,)
# vxper_iv2 = griddata(points, values, query_points, method='linear')
# vxper_iv2 = vxper_iv2.reshape(lat_grid.shape)



# =============================================================================
# # plotting arrows
# =============================================================================




# Transparent cartesian axes laid over the higher-altitude temperature map.
# The model wind vectors (vxx2gr, vyy2gr) are drawn in the same [-0.5, 0.5]
# frame as the other overlays, instead of through the map projection.
d_temp_pos = ax.get_position()
ax_overlay = fig.add_subplot(d_temp_pos, facecolor="none")
ax_overlay.set_xlim(-0.5, 0.5)
ax_overlay.set_ylim(-0.5, 0.5)
ax_overlay.tick_params(left=False, right=False, labelleft=False,
                       labelright=False, bottom=False, labelbottom=False)

# Only the vector components are rotated here; they are directions, so no
# translation is needed. Positions reuse `xgr_rot`/`ygr_rot` and `theta`
# from the first model-temperature overlay.
vxx2gr_rot, vyy2gr_rot = (
    vxx2gr * np.cos(theta) - vyy2gr * np.sin(theta),
    vxx2gr * np.sin(theta) + vyy2gr * np.cos(theta),
)

ax_overlay.quiver(
    xgr_rot.ravel(), ygr_rot.ravel(),
    vxx2gr_rot.ravel(), vyy2gr_rot.ravel(),
    scale=15, width=0.0075, pivot='mid', color="xkcd:pale green",
)

# ax_overlay.quiver(xgr,ygr,jyygr,-jxxgr,scale=20,width=0.01,pivot='mid',color="xkcd:black")

# xgr, ygr




# ax.plot(lon_grid, lat_grid, 'x',transform=ccrs.PlateCarree())

# the uninterpolated arrows
# ax.quiver(360-lon_idx+phaseshift,lat_idx-90, np.flipud(vyper2),np.flipud(-vxper2), transform=ccrs.PlateCarree(),scale=3e2,width=0.01,pivot='mid',color='black')

# the interpolated arrows
# ax.quiver(lon_grid+phaseshift,lat_grid, -vyper_iv2,-vxper_iv2, transform=ccrs.PlateCarree(),scale=2e2,width=0.009,color='xkcd:pale green',pivot='mid')


# Psi_N phase labels (raw strings: "\c" and "\P" are invalid escapes in
# ordinary literals), dotted meridians at 45, 135, 225 and 315 deg, and the
# pressure/altitude tag for this model level.
ax.text(360-215,66.5,r'90$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
ax.text(360-45,65,r'270$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="right", va="center",size=8,c='hotpink')
ax.text(360-315,66.5,r'180$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
# ax.text(360-330,69,'90$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='mediumvioletred')
ax.text(360-140,63,r'0$^{\circ}\Psi_N$', transform=ccrs.PlateCarree(), ha="left", va="center",size=8,c='hotpink')
for drift in range(4): ax.plot([45+90*drift,45+90*drift],[90,40], transform=ccrs.PlateCarree(),c='hotpink',linestyle='dotted')


ax.text(360-270,69,'~3 nbar, ~1300 km', transform=ccrs.PlateCarree(), ha="center", va="top", weight='bold',bbox=bbox_props1)

# Write the assembled figure to PDF, then display it.
fig.savefig('asd_temper_fig2d.pdf', dpi=300, bbox_inches='tight', facecolor='white', pad_inches=0)

plt.show()

# %%


# fig = plt.figure(figsize=(12,6),dpi=300)

# ax_overlay = fig.add_subplot(111, facecolor="none")
# # ax_overlay.yaxis.set_label_position("right")
# # ax_overlay.tick_params(left=False, right=False, labelleft=False, labelright=False,
#                 # bottom=False, labelbottom=False)

# ax_overlay.quiver(xgr.ravel(),ygr.ravel(),(jxxgr).ravel(),(jyygr).ravel(),scale=35,width=0.002,pivot='mid',color="xkcd:black")
# ax_overlay.set_xlim(-0.5,0.5)
# ax_overlay.set_ylim(-0.5,0.5)